TFRecords (Part 2): Reading and training models with TFRecords.

In the first part, I showed how to convert a dataset of medical images and their target values into TFRecords. Here, I will show how to read TFRecords and how to train an ML model in TensorFlow using them. For this, I will be using the TFRecords I created in the first part.

Reading TFRecords

TFRecords store data in a binary format for fast and easy access. To read an example from a TFRecord file, we first need to create a function that parses each example, and then use a tf.data.TFRecordDataset object together with the parsing function to read the file. This is a straightforward process.

In the first part, we put two features (image and target) from our dataset into the TFRecord files. First, we create a feature_description dictionary to read each feature. The keys in the dictionary must be the same as the keys used to store the features in the TFRecord files. tf.io.FixedLenFeature reads each feature and stores it in the given data type. Here, tf.string is used for the image feature since it was stored as an encoded byte string, and tf.int64 is used for the target value since it's an integer.

import tensorflow as tf

def parse_tfrecord_fn(example):
    # The keys must match those used to write the tfrecord files
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(example, feature_description)
    # Decode the PNG byte string back into an image tensor
    example["image"] = tf.io.decode_png(example["image"])
    return example

The parse_tfrecord_fn reads each example and maps each feature to its corresponding data type. The tf.io.decode_png function does the opposite of the tf.io.encode_png function introduced in Part 1: it converts the PNG-encoded byte string back into an image tensor.
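
As a quick sanity check, here is a minimal round trip between the two functions (a sketch using a blank placeholder tensor in place of a real image):

img = tf.zeros([224, 224, 1], dtype=tf.uint8)   # placeholder grayscale image
png_bytes = tf.io.encode_png(img)               # tensor -> PNG byte string
restored = tf.io.decode_png(png_bytes)          # PNG byte string -> tensor
print(tf.reduce_all(restored == img).numpy())   # True, since PNG is lossless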

After creating the parsing function, it's fairly easy to read the files. First, use tf.data.TFRecordDataset to read the TFRecord file from its path, then call the map method with parse_tfrecord_fn as the argument.

raw_dataset = tf.data.TFRecordDataset(".../tfrecords/tfrecord_0-1000.tfrec")
parsed_dataset = raw_dataset.map(parse_tfrecord_fn)

Now that the dataset is parsed, to read an example from it, we use the take method, specifying the number of examples we want to read as the argument. This is similar to how we use df.head() in pandas.

import matplotlib.pyplot as plt

for example in parsed_dataset.take(5):
    for key in example.keys():
        print(f"{key}: {type(example[key])}")

    print(f"Image shape: {example['image'].shape}")
    plt.figure(figsize=(7, 7))
    # squeeze drops the trailing channel dimension so imshow can plot it
    plt.imshow(example["image"].numpy().squeeze(), cmap="gray")
    plt.show()

parsed_dataset.take(5) takes five examples from the dataset. You can access them using a for loop or by materializing them into a list with list(parsed_dataset.take(5)). Either way, each example is stored as a dictionary, with feature names as keys and feature tensors as values.
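
For instance, a minimal sketch of the list approach:

examples = list(parsed_dataset.take(5))   # materialize five examples
first = examples[0]                       # a dict of feature name -> tensor
print(first["target"].numpy())            # the stored integer label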


Training a TensorFlow model using TFRecords

Training a TensorFlow/Keras deep learning model using TFRecords is very easy. First, define the model.

Since the dataset consists of images and their target values, I have chosen the VGG16 model from tf.keras.applications, and I will be applying transfer learning: the model weights are set to imagenet and the VGG16 layers are frozen by setting trainable to False. Here's how it's done using the Keras functional API.

num_classes = 1
input_image = tf.keras.Input(shape=(224, 224, 3), name='image')

# Load the VGG16 model
vgg16 = tf.keras.applications.VGG16(weights='imagenet',
                                    include_top=False,
                                    input_shape=(224, 224, 3))

vgg16.trainable = False

x = vgg16(input_image)

y = tf.keras.layers.Flatten()(x)

z = tf.keras.layers.Dense(128, activation='relu')(y)

output = tf.keras.layers.Dense(num_classes, activation='sigmoid')(z)

# Create the model
model = tf.keras.Model(inputs=[input_image], outputs=output)

model.summary()

The second line of the code above defines the Input of the model. Here, we are using a tensor of shape (224, 224, 3), which is the size of the image, and assigning the input the name "image". The name helps TensorFlow know which feature from the TFRecord file to use as input. This is particularly helpful when you have more than one feature, as in the sketch below.
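
As a hypothetical illustration (the TFRecords from Part 1 only contain image and target, so the extra age feature here is made up purely to show the naming), a model with two named inputs could look like this:

# Hypothetical second feature: each Input's name must match the key in the
# feature dictionary returned by the parsing function.
input_image = tf.keras.Input(shape=(224, 224, 3), name='image')
input_age = tf.keras.Input(shape=(1,), name='age')

x = tf.keras.layers.Flatten()(input_image)
x = tf.keras.layers.Concatenate()([x, input_age])
multi_output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

multi_input_model = tf.keras.Model(inputs=[input_image, input_age],
                                   outputs=multi_output)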

Next, we define a parsing function that converts each example into the (features, target) format TensorFlow expects when fitting the model.

def parse_example(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_png(example["image"])
    X = dict()
    # Replicate the single grayscale channel to get shape (224, 224, 3)
    X['image'] = tf.image.grayscale_to_rgb(example['image'])
    y = example['target']
    return X, y

Here, the function returns a tuple (X, y) where X holds the features and y the target. The parsing logic at the top of the function is the same as in parse_tfrecord_fn. The main difference is the dictionary assigned to X. The next line takes the image tensor from example['image'] and converts it from shape (224, 224, 1) to (224, 224, 3) (you can read the documentation of tf.image.grayscale_to_rgb for more details). The converted tensor is assigned to X['image']. X is the feature dictionary we want to return along with the target, and image is used as the key to match the name of the Input defined in the model above.
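
A quick check of what the conversion does, using a placeholder tensor:

gray = tf.zeros([224, 224, 1], dtype=tf.uint8)   # single-channel placeholder
rgb = tf.image.grayscale_to_rgb(gray)            # channel replicated three times
print(rgb.shape)                                 # (224, 224, 3)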

Now, all we have to do is read the files. Since we created more than one TFRecord file, training the model on all of the available data means reading all the files. Thankfully, tf.data.TFRecordDataset accepts both a single file path and a list of file paths as input.

import glob

tf_files = glob.glob('/kaggle/input/notebook9fa8840645/tfrecords/*.tfrec')
dataset = tf.data.TFRecordDataset(tf_files)
dataset = dataset.map(parse_example)

This time around, when calling the map method, we use the new parsing function parse_example as the argument.

Now, to train the model using the dataset, we call the batch method on the dataset and pass the result as the argument to model.fit.

dataset = dataset.shuffle(buffer_size=1024).batch(32)

model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model (val_dataset comes from the train/val/test split shown below)
model.fit(dataset, epochs=10, validation_data=val_dataset)

Here, the shuffle method fills a buffer with 1024 examples and draws samples randomly from it, and the batch method then groups the shuffled examples into batches of 32.
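
As a quick sanity check, each element of the batched dataset is now a (features, target) pair, so a sketch like this should print the batch shapes:

for X_batch, y_batch in dataset.take(1):
    print(X_batch['image'].shape)   # (32, 224, 224, 3)
    print(y_batch.shape)            # (32,)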


How about a Train/Val/Test Split?

tf.data.TFRecordDataset does not provide a built-in way to perform a train-test split, but it's fairly straightforward if you know the number of examples in the dataset. TensorFlow provides a function to get the size, but for TFRecord datasets it often cannot determine it, since counting the records would require reading the files. You can check whether the size of the dataset is known using the following block of code.

cardinality = tf.data.experimental.cardinality(dataset)
print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())

If this prints True, the size of the dataset is not known. If it prints False, it is known, and you can get the size using tf.data.experimental.cardinality or dataset.cardinality().
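
When the size is unknown, one simple (if slow) fallback is to count the records by iterating over the raw files once. This sketch assumes the tf_files list from earlier:

# Reads every file once just to count the records
dataset_size = sum(1 for _ in tf.data.TFRecordDataset(tf_files))
print(dataset_size)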

Using the function below, you can create a train-val-test split.

def train_val_test_split(dataset,
                         dataset_size,
                         train_split=0.8,
                         val_split=0.1,
                         test_split=0.1,
                         shuffle=True,
                         shuffle_size=10000):
    assert (train_split + test_split + val_split) == 1

    if shuffle:
        # Specify seed to always have the same split distribution between runs
        dataset = dataset.shuffle(shuffle_size, seed=12)

    train_size = int(train_split * dataset_size)
    val_size = int(val_split * dataset_size)

    train_dataset = dataset.take(train_size)
    val_dataset = dataset.skip(train_size).take(val_size)
    test_dataset = dataset.skip(train_size).skip(val_size)

    return train_dataset, val_dataset, test_dataset
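
And a hedged usage sketch: the split should be applied to the unbatched parsed dataset, with each split batched separately afterwards (dataset_size comes from the counting approach above, or from cardinality when it is known):

dataset = tf.data.TFRecordDataset(tf_files).map(parse_example)
train_dataset, val_dataset, test_dataset = train_val_test_split(dataset, dataset_size)

train_dataset = train_dataset.batch(32)
val_dataset = val_dataset.batch(32)
test_dataset = test_dataset.batch(32)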

by sodipe🌚 on February 23, 2023.